EchoSpike Predictive Plasticity¶

In [ ]:
import matplotlib.pyplot as plt
import os
from utils import  get_accuracy, get_samples, train_out_proj_fast, train_out_proj_closed_form
from main import Args
from data import load_SHD
from model import EchoSpike, simple_out
import numpy as np
from data import augment_shd
import torch
import seaborn as sns
from scipy.signal import savgol_filter
from tqdm.notebook import trange
from matplotlib import pyplot
pyplot.rcParams['figure.dpi'] = 600
import pickle
torch.manual_seed(0)
color_list = sns.color_palette('muted')
device = 'cpu'
batch_size = 64
folder = 'models/'
model_name = folder + 'shd_4layer_withckpts_noaugment.pt'
with open(model_name[:-3] + '_args.pkl', 'rb') as f:
    args = pickle.load(f)
# args = Args()
online = args.online
print(vars(args))
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
  warnings.warn(_BETA_TRANSFORMS_WARNING)
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
  warnings.warn(_BETA_TRANSFORMS_WARNING)
{'model_name': 'shd_4layer_withckpts_noaugment', 'dataset': 'shd', 'online': True, 'device': 'cuda', 'recurrency_type': 'none', 'lr': 0.0001, 'epochs': 1000, 'augment': False, 'batch_size': 128, 'n_hidden': [450, 450, 450, 450], 'inp_thr': 0.05, 'c_y': [1.5, -1.5], 'n_inputs': 700, 'n_outputs': 20, 'n_time_bins': 100, 'beta': 0.95}

Dataset¶

Spiking Heidelberg Digits

In [ ]:
#train_loader, test_loader = load_PMNIST(n_time_bins, scale=0.9, patches=True) #load_NMNIST(n_time_bins, batch_size=batch_size)
n_time_bins = 100
train_loader, test_loader = load_SHD(batch_size=batch_size) #load_NMNIST(n_time_bins, batch_size=batch_size)
# Plot Example(s)
for i in range(1):
    frames, target = train_loader.next_item(-1, contrastive=True)
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(frames.squeeze(1).T)
    # plt.colorbar()
    print(frames.shape, target)
plt.axis('on')
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.y = torch.tensor(y)
torch.Size([100, 1, 700]) tensor([4.])
Out[ ]:
(-0.5, 99.5, 699.5, -0.5)

Load pretrained model¶

In [ ]:
SNN = EchoSpike(args.n_inputs, args.n_hidden, beta=args.beta, c_y=args.c_y, device=device, recurrency_type=args.recurrency_type, online=args.online).to(device)
SNN.load_state_dict(torch.load(model_name, map_location=device))
# train(SNN, train_loader, args.epochs, device, args.model_name,
                            # batch_size=args.batch_size, online=args.online, lr=1e-8, augment=args.augment)
from_epoch = 0
echo_train_loss = torch.load(model_name[:-3]+'_loss_hist.pt', map_location='cpu')[int(from_epoch*len(train_loader)/args.batch_size):]
print(echo_train_loss.shape)
for i in range(echo_train_loss.shape[-1]):
    plt.plot(from_epoch+(args.batch_size*np.arange(echo_train_loss.shape[0])/len(train_loader)), savgol_filter(echo_train_loss[:,i], 99, 1), color=color_list[i])
plt.legend([f'layer {i+1}' for i in range(len(SNN.layers))])
# no y ticks, because it's not really meaningful
plt.yticks([])
plt.xlabel('Epoch')
plt.ylabel('EchoSpike Loss')
torch.Size([63719, 4])
Out[ ]:
Text(0, 0.5, 'EchoSpike Loss')
In [ ]:
# plotting adaptive threshold and update rate for an example
# init_echo, label_0 = train_loader.next_item(-1, contrastive=True)
# sample_1, label_1 = train_loader.next_item(-1, contrastive=True)
# sample_2, label_2 = train_loader.next_item(label_1, contrastive=False)
print(label_0, label_1, label_2)
SNN.eval()
with torch.no_grad():
    # feed first sample to get initial activity
    for t in range(100):
        inp_activity = init_echo[t].mean(axis=-1)
        SNN(init_echo[t], torch.tensor(-1, device=device), inp_activity=inp_activity)
    SNN.reset(-1)
    # feed second sample to get the update rates and thresholds for contrastive case
    contrastive_thresholds = torch.zeros(100)
    contrastive_temp_sim = torch.zeros((len(SNN.layers), 100))
    for t in range(100):
        inp_activity = sample_1[t].mean(axis=-1)
        out_spk, mems, losses = SNN(sample_1[t], torch.tensor(-1, device=device), inp_activity=inp_activity)
        contrastive_thresholds[t] = inp_activity * args.c_y[1]
        contrastive_temp_sim[:, t] = losses
    SNN.reset(-1)
    # feed third sample to get the update rates and thresholds for predictive case
    predictive_thresholds = torch.zeros(100)
    predictive_temp_sim = torch.zeros((len(SNN.layers), 100))
    for t in range(100):
        inp_activity = sample_2[t].mean(axis=-1)
        out_spk, mems, losses = SNN(sample_1[t], torch.tensor(1, device=device), inp_activity=inp_activity)
        predictive_thresholds[t] = inp_activity * args.c_y[0]
        predictive_temp_sim[:, t] = -losses
    SNN.reset(1)
    # plot thresholds, with sample as background
    layer = 2
    fig, ax = plt.subplots(figsize=(10, 5))
    ax2  = ax.twinx()
    # imshow in background
    ax.imshow(sample_1.squeeze(1).T, aspect='auto', cmap='Reds')
    ax2.plot(-contrastive_temp_sim[layer], color='r', label='Negative Similarity Score')
    ax2.plot(contrastive_thresholds, color='r', linestyle='--', label='Contrastive Threshold')
    ax2.hlines(args.inp_thr*args.c_y[1], 0, 100, color='r', linestyle=':', label='Input Threshold (times c(-1))')
    # highlight regions where the thresholds are crossed
    argwhere = np.argwhere(np.logical_and((-contrastive_temp_sim[layer] < contrastive_thresholds).numpy(), contrastive_thresholds.numpy() < args.inp_thr*args.c_y[1]))
    for i in range(argwhere.shape[0]):
        ax2.axvspan(argwhere[i], argwhere[i]+1, color='r', alpha=0.2, lw=0)

    ax.yaxis.set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.yaxis.tick_left()
    ax2.yaxis.set_label_position('left')
    ax2.set_xlim(ax.get_xlim())
    # get rid of right margin
    ax2.margins(0)
    ax.set_xlabel('Timesteps')
    plt.ylabel('Thresholds & Similarity Score')
    plt.xlim(0, 100)
    plt.legend()
    # same for predictive
    fig, ax = plt.subplots(figsize=(10, 5))
    ax2  = ax.twinx()
    # imshow in background
    ax.imshow(sample_2.squeeze(1).T, aspect='auto', cmap='Blues')
    ax2.plot(predictive_temp_sim[layer], color='b', label='Similarity Score')
    ax2.plot(predictive_thresholds, color='b', linestyle='--', label='Predictive Threshold')
    ax2.hlines(args.inp_thr*args.c_y[0], 0, 100, color='b', linestyle=':', label='Input Threshold (times c(1))')
    # highlight regions where the thresholds are crossed
    argwhere = np.argwhere(np.logical_and((predictive_temp_sim[layer] < predictive_thresholds).numpy(), predictive_thresholds.numpy() > args.inp_thr*args.c_y[0]))
    for i in range(argwhere.shape[0]):
        ax2.axvspan(argwhere[i], argwhere[i]+1, color='b', alpha=0.1, lw=0)
    ax.yaxis.set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.yaxis.tick_left()
    ax2.yaxis.set_label_position('left')
    ax2.set_xlim(ax.get_xlim())
    # get rid of right margin
    #ax2.margins(0)
    ax.set_xlabel('Timesteps')
    plt.ylabel('Thresholds & Similarity Score')
    plt.xlim(0, 100)
    plt.legend()
    plt.show()
tensor([9.]) tensor([7.]) tensor([7.])
/home/lars/miniconda3/lib/python3.9/site-packages/matplotlib/patches.py:1111: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  xy = np.asarray(xy)
/home/lars/miniconda3/lib/python3.9/site-packages/matplotlib/patches.py:1111: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  xy = np.asarray(xy)

Analyze Weights Directly¶

In [ ]:
layers = [SNN.layers[0].fc.weight[:,:args.n_inputs]]
for i in range(1, len(SNN.layers)):
    layers.append(SNN.layers[i].fc.weight[:,:args.n_hidden[i-1]] @ layers[-1])

for i in range(len(SNN.layers)):
    plt.figure()
    plt.imshow(SNN.layers[i].fc.weight.detach(), vmax=0.1, vmin=-0.1)
    plt.colorbar()
    # plt.figure()
    # plt.imshow(SNN.layers[i].pred.weight.detach(), vmax=0.5, vmin=-0.5)
    # plt.colorbar()
for lay in layers:
    plt.figure()
    plt.imshow(lay.detach())
    plt.colorbar()
Error in callback <function flush_figures at 0x7f47940f8700> (for post_execute):
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/miniconda3/lib/python3.9/site-packages/PIL/ImageFile.py:495, in _save(im, fp, tile, bufsize)
    494 try:
--> 495     fh = fp.fileno()
    496     fp.flush()

AttributeError: '_idat' object has no attribute 'fileno'

During handling of the above exception, another exception occurred:

KeyboardInterrupt                         Traceback (most recent call last)
File ~/miniconda3/lib/python3.9/site-packages/matplotlib_inline/backend_inline.py:126, in flush_figures()
    123 if InlineBackend.instance().close_figures:
    124     # ignore the tracking, just draw and close all figures
    125     try:
--> 126         return show(True)
    127     except Exception as e:
    128         # safely show traceback if in IPython, else raise
    129         ip = get_ipython()

File ~/miniconda3/lib/python3.9/site-packages/matplotlib_inline/backend_inline.py:90, in show(close, block)
     88 try:
     89     for figure_manager in Gcf.get_all_fig_managers():
---> 90         display(
     91             figure_manager.canvas.figure,
     92             metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
     93         )
     94 finally:
     95     show._to_draw = []

File ~/miniconda3/lib/python3.9/site-packages/IPython/core/display_functions.py:298, in display(include, exclude, metadata, transient, display_id, raw, clear, *objs, **kwargs)
    296     publish_display_data(data=obj, metadata=metadata, **kwargs)
    297 else:
--> 298     format_dict, md_dict = format(obj, include=include, exclude=exclude)
    299     if not format_dict:
    300         # nothing to display (e.g. _ipython_display_ took over)
    301         continue

File ~/miniconda3/lib/python3.9/site-packages/IPython/core/formatters.py:178, in DisplayFormatter.format(self, obj, include, exclude)
    176 md = None
    177 try:
--> 178     data = formatter(obj)
    179 except:
    180     # FIXME: log the exception
    181     raise

File <decorator-gen-2>:2, in __call__(self, obj)

File ~/miniconda3/lib/python3.9/site-packages/IPython/core/formatters.py:222, in catch_format_error(method, self, *args, **kwargs)
    220 """show traceback on failed format call"""
    221 try:
--> 222     r = method(self, *args, **kwargs)
    223 except NotImplementedError:
    224     # don't warn on NotImplementedErrors
    225     return self._check_return(None, args[0])

File ~/miniconda3/lib/python3.9/site-packages/IPython/core/formatters.py:339, in BaseFormatter.__call__(self, obj)
    337     pass
    338 else:
--> 339     return printer(obj)
    340 # Finally look for special method names
    341 method = get_real_method(obj, self.print_method)

File ~/miniconda3/lib/python3.9/site-packages/IPython/core/pylabtools.py:151, in print_figure(fig, fmt, bbox_inches, base64, **kwargs)
    148     from matplotlib.backend_bases import FigureCanvasBase
    149     FigureCanvasBase(fig)
--> 151 fig.canvas.print_figure(bytes_io, **kw)
    152 data = bytes_io.getvalue()
    153 if fmt == 'svg':

File ~/miniconda3/lib/python3.9/site-packages/matplotlib/backend_bases.py:2319, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2315 try:
   2316     # _get_renderer may change the figure dpi (as vector formats
   2317     # force the figure dpi to 72), so we need to set it again here.
   2318     with cbook._setattr_cm(self.figure, dpi=dpi):
-> 2319         result = print_method(
   2320             filename,
   2321             facecolor=facecolor,
   2322             edgecolor=edgecolor,
   2323             orientation=orientation,
   2324             bbox_inches_restore=_bbox_inches_restore,
   2325             **kwargs)
   2326 finally:
   2327     if bbox_inches and restore_bbox:

File ~/miniconda3/lib/python3.9/site-packages/matplotlib/backend_bases.py:1648, in _check_savefig_extra_args.<locals>.wrapper(*args, **kwargs)
   1640     _api.warn_deprecated(
   1641         '3.3', name=name, removal='3.6',
   1642         message='%(name)s() got unexpected keyword argument "'
   1643                 + arg + '" which is no longer supported as of '
   1644                 '%(since)s and will become an error '
   1645                 '%(removal)s')
   1646     kwargs.pop(arg)
-> 1648 return func(*args, **kwargs)

File ~/miniconda3/lib/python3.9/site-packages/matplotlib/_api/deprecation.py:412, in delete_parameter.<locals>.wrapper(*inner_args, **inner_kwargs)
    402     deprecation_addendum = (
    403         f"If any parameter follows {name!r}, they should be passed as "
    404         f"keyword, not positionally.")
    405     warn_deprecated(
    406         since,
    407         name=repr(name),
   (...)
    410                  else deprecation_addendum,
    411         **kwargs)
--> 412 return func(*inner_args, **inner_kwargs)

File ~/miniconda3/lib/python3.9/site-packages/matplotlib/backends/backend_agg.py:541, in FigureCanvasAgg.print_png(self, filename_or_obj, metadata, pil_kwargs, *args)
    494 """
    495 Write the figure to a PNG file.
    496 
   (...)
    538     *metadata*, including the default 'Software' key.
    539 """
    540 FigureCanvasAgg.draw(self)
--> 541 mpl.image.imsave(
    542     filename_or_obj, self.buffer_rgba(), format="png", origin="upper",
    543     dpi=self.figure.dpi, metadata=metadata, pil_kwargs=pil_kwargs)

File ~/miniconda3/lib/python3.9/site-packages/matplotlib/image.py:1675, in imsave(fname, arr, vmin, vmax, cmap, format, origin, dpi, metadata, pil_kwargs)
   1673 pil_kwargs.setdefault("format", format)
   1674 pil_kwargs.setdefault("dpi", (dpi, dpi))
-> 1675 image.save(fname, **pil_kwargs)

File ~/miniconda3/lib/python3.9/site-packages/PIL/Image.py:2212, in Image.save(self, fp, format, **params)
   2209         fp = builtins.open(filename, "w+b")
   2211 try:
-> 2212     save_handler(self, fp, filename)
   2213 finally:
   2214     # do what we can to clean up
   2215     if open_fp:

File ~/miniconda3/lib/python3.9/site-packages/PIL/PngImagePlugin.py:1348, in _save(im, fp, filename, chunk, save_all)
   1346     _write_multiple_frames(im, fp, chunk, rawmode)
   1347 else:
-> 1348     ImageFile._save(im, _idat(fp, chunk), [("zip", (0, 0) + im.size, 0, rawmode)])
   1350 if info:
   1351     for info_chunk in info.chunks:

File ~/miniconda3/lib/python3.9/site-packages/PIL/ImageFile.py:509, in _save(im, fp, tile, bufsize)
    507 else:
    508     while True:
--> 509         l, s, d = e.encode(bufsize)
    510         fp.write(d)
    511         if s:

KeyboardInterrupt: 

Train output Projection¶

In [ ]:
from tqdm.notebook import tqdm
def train_out_proj(epochs, batch, cat, out_projs=None):
    # train output projections from all layers (and no layer)
    losses_out = []
    beta = 1.0
    lr = 1e-4
    augment = True
    optimizers = []
    print_interval = 10*batch
    if out_projs is None:
        out_projs = []
        out_proj_0 = simple_out(700, 20, beta=beta)
    else:
        for out_p in out_projs:
            out_p.train()
            out_p.reset()
        out_proj_0 = out_projs[0]
        out_projs = out_projs[1:]
    optim_0 = torch.optim.Adam(out_proj_0.parameters(), lr=lr)
    for lay in range(len(SNN.layers)):
        if len(out_projs) <= lay:
            if cat:
                out_projs.append(simple_out(sum(args.n_hidden[:lay+1])+700, 20, beta=beta))
            else:
                out_projs.append(simple_out(args.n_hidden[lay], 20, beta=beta))
        optimizers.append(torch.optim.Adam(out_projs[lay].parameters(), lr=lr))
        optimizers[-1].zero_grad()
    SNN.eval()
    acc = []
    target = batch_size*[-1]
    correct = (len(SNN.layers) + 1)*[0]
    with torch.no_grad():
        pbar = tqdm(total=len(train_loader)*epochs)
        while len(losses_out)*batch < len(train_loader)*epochs:
            data, target = train_loader.next_item(target, contrastive=True)
            SNN.reset(0)
            logit_lists = [[] for _ in range(len(SNN.layers)+1)]
            data = data.squeeze()
            if augment:
                data = augment_shd(data)
            for step in range(data.shape[0]):
                data_step = data[step].float().to(device)
                target = target.to(device)
                logits, _, _ = SNN(data_step, 0)
                if step == args.n_time_bins-1:
                    _, logts = out_proj_0(data_step, target)
                    logit_lists[0] = logts
                    for lay in range(len(SNN.layers)):
                        if cat:
                            data_step = torch.cat([data_step, logits[lay]], dim=-1)
                        else:
                            data_step = logits[lay]
                        _, logts = out_projs[lay](data_step, target)
                        logit_lists[lay+1] = logts
                else:
                    out_proj_0(data_step, None)
                    for lay in range(len(SNN.layers)):
                        if cat:
                            data_step = torch.cat([data_step, logits[lay]], dim=-1)
                        else:
                            data_step = logits[lay]
                        out_projs[lay](data_step, None)
            
            preds = [logit_lists[lay].argmax(axis=-1) for lay in range(len(SNN.layers)+1)]
            correct = [correct[lay] + (preds[lay] == target).sum() for lay in range(len(SNN.layers)+1)]
            out_proj_0.reset()
            for i, out_proj in enumerate(out_projs):
                out_proj.reset()

            losses_out.append(torch.tensor([torch.nn.functional.cross_entropy(logit_lists[lay], target.squeeze().long()) for lay in range(len(SNN.layers)+1)], requires_grad=False))

            optim_0.step()
            optim_0.zero_grad()
            for opt in optimizers:
                opt.step()
                opt.zero_grad()
            
            if len(losses_out)*batch % print_interval == 0:
                pbar.write(f'Cross Entropy Loss: {(torch.stack(losses_out)[-print_interval//batch:].sum(dim=0)/(print_interval//batch)).numpy()}\n' +
                           f'Correct: {100*np.array(correct)/print_interval}%')
                acc.append(np.array(correct)/print_interval)
                correct = (len(SNN.layers) + 1)*[0]
            pbar.update(batch)
    return [out_proj_0, *out_projs], np.asarray(acc), torch.stack(losses_out)

with torch.no_grad():
    if args.augment:
        n_epochs = 100
        cat = True
        # if already trained, load the output projections
        if os.path.exists(model_name[:-3]+'_out_projs.pt'):
            out_projs = torch.load(model_name[:-3]+'_out_projs.pt', map_location=device)
        else:
            out_projs, acc, losses_out = train_out_proj(n_epochs, batch_size, cat)
            torch.save(out_projs, model_name[:-3]+'_out_projs.pt')
In [ ]:
if not args.augment:
    # if snn_samples already exist, don't recompute
    try:
        snn_samples
    except:
        snn_samples, targets = get_samples(SNN, train_loader, args.n_hidden, device)
    cat = True
    test_accs = []
    train_accs = []
    for i in range(10):
        with torch.no_grad():
            out_projs, acc, losses_out = train_out_proj_fast(SNN, args, 30, 60, snn_samples, train_loader.y, cat=cat, lr=3e-4, weight_decay=0)
        print('Mean abs weights', out_projs[-1].out_proj.weight.abs().mean())
        test_accs.append(get_accuracy(SNN, out_projs, test_loader, device, cat=cat)[0])
        train_accs.append(get_accuracy(SNN, out_projs, train_loader, device, cat=cat)[0])
    test_accs = torch.stack([torch.tensor(ta) for ta in test_accs])
    train_accs = torch.stack([torch.tensor(ta) for ta in train_accs])
    print(f'Fast Classifier Mean Test Accuracy: {100*torch.mean(test_accs, dim=0)}, Std: {100*torch.std(test_accs, dim=0)}')
    print(f'Fast Classifier Mean Train Accuracy: {100*torch.mean(train_accs, dim=0)}, Std: {100*torch.std(train_accs, dim=0)}')
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.1872716 3.1436896 2.324911  1.8332611 1.7116711]
Correct: [15.91466405 28.84992643 41.88327612 53.29818538 56.47376165]%
Cross Entropy Loss: [2.463053   1.4402385  0.9932619  0.78709584 0.7274044 ]
Correct: [29.61010299 52.52574792 66.63805787 74.35017165 76.85139774]%
Cross Entropy Loss: [2.1605024  1.1310219  0.75093096 0.6065797  0.5781179 ]
Correct: [35.49534085 61.73369299 74.46051986 79.78175576 80.78715056]%
Cross Entropy Loss: [1.9743724  0.98842406 0.6655614  0.52868205 0.4756272 ]
Correct: [40.10544385 66.84649338 78.05296714 82.74889652 83.88916135]%
Cross Entropy Loss: [1.8470854  0.86089724 0.5742121  0.44715923 0.39279133]
Correct: [43.79597842 71.21137813 81.37567435 85.63021089 87.02795488]%
Cross Entropy Loss: [1.7041218  0.7355151  0.4549692  0.36407518 0.30456907]
Correct: [45.63511525 75.14713095 85.26238352 87.86169691 89.66405101]%
Cross Entropy Loss: [1.6372793  0.69917977 0.42999554 0.3589159  0.29322165]
Correct: [47.27807749 76.77783227 85.66699362 88.03334968 90.17900932]%
Cross Entropy Loss: [1.544904  0.6397305 0.4042687 0.2856387 0.2902844]
Correct: [50.73565473 78.64149093 86.75821481 90.57135851 90.39970574]%
Cross Entropy Loss: [1.4992337  0.561572   0.34116337 0.2432679  0.20581922]
Correct: [50.94409024 80.93428151 89.02648357 92.21432075 93.12162825]%
Cross Entropy Loss: [1.4029716  0.5246412  0.30427465 0.24020304 0.22663431]
Correct: [53.82540461 82.24619912 90.31387935 92.09171162 92.71701815]%
Cross Entropy Loss: [1.3696545  0.47617233 0.27150342 0.21061409 0.1887534 ]
Correct: [54.48749387 84.67385974 91.47866601 93.37910741 93.84502207]%
Cross Entropy Loss: [1.3367599  0.46279103 0.23594631 0.1949752  0.17442143]
Correct: [56.04462972 84.64933791 92.55762629 94.01667484 94.1760667 ]%
Cross Entropy Loss: [1.2394575  0.42377707 0.2390069  0.19376123 0.1666555 ]
Correct: [58.01863659 85.96125552 92.5698872  93.63658656 94.4825895 ]%
Cross Entropy Loss: [1.2089213  0.3754208  0.19392681 0.15874341 0.12948033]
Correct: [59.10985777 87.45708681 94.1760667  94.7768514  96.03972536]%
Cross Entropy Loss: [1.1781707  0.37796757 0.19277297 0.13802704 0.1132929 ]
Correct: [61.16969103 87.27317312 94.04119667 95.72094164 96.39529181]%
Cross Entropy Loss: [1.1654673  0.35604432 0.19146349 0.14440024 0.13116387]
Correct: [60.5443845  88.37665522 94.00441393 95.40215792 96.11329083]%
Cross Entropy Loss: [1.0976348  0.30079675 0.15331014 0.11085778 0.08470406]
Correct: [62.21186856 90.58361942 95.72094164 96.51790093 97.47425208]%
Cross Entropy Loss: [1.0685236  0.29694378 0.14392856 0.1122706  0.06621685]
Correct: [64.21039725 90.77979402 95.73320255 96.56694458 98.24668955]%
Cross Entropy Loss: [1.0360285  0.28520513 0.12562898 0.09520219 0.07301001]
Correct: [64.51692006 91.09857773 96.65277097 97.10642472 97.76851398]%
Cross Entropy Loss: [0.9997829  0.27274328 0.11821497 0.08664852 0.07347295]
Correct: [65.35066209 91.34379598 96.8857283  97.21677293 97.75625307]%
Cross Entropy Loss: [0.98923266 0.25430578 0.11565954 0.08620941 0.07150187]
Correct: [65.44874939 92.34919078 96.93477195 97.40068661 97.85434036]%
Cross Entropy Loss: [0.9962766  0.23677072 0.12023983 0.06941527 0.06794292]
Correct: [66.38057872 92.50858264 96.71407553 98.25895047 98.12408043]%
Cross Entropy Loss: [0.93273515 0.23093246 0.09307903 0.06428653 0.06228227]
Correct: [68.15841099 93.18293281 97.74399215 98.34477685 98.32025503]%
Cross Entropy Loss: [0.9235989  0.23393106 0.10302919 0.0937719  0.0582747 ]
Correct: [68.25649828 92.39823443 97.36390387 96.95929377 98.33251594]%
Cross Entropy Loss: [0.9203787  0.2202786  0.09355038 0.05742612 0.03718374]
Correct: [68.86954389 92.88867092 97.58460029 98.55321236 99.0926925 ]%
Cross Entropy Loss: [0.9017033  0.18595962 0.06985462 0.05069679 0.05567232]
Correct: [69.13928396 94.5806768  98.52869053 98.7984306  98.35703776]%
Cross Entropy Loss: [0.8860314  0.20338023 0.08454282 0.07205418 0.06439789]
Correct: [69.47032859 93.74693477 97.65816577 97.73173124 97.82981854]%
Cross Entropy Loss: [0.85007566 0.16172642 0.06845957 0.05027504 0.05842838]
Correct: [70.89259441 95.65963708 98.44286415 98.71260422 98.27121138]%
Cross Entropy Loss: [0.84972966 0.15339603 0.06147075 0.06692296 0.06831121]
Correct: [71.21137813 96.03972536 98.68808239 97.90338401 98.1976459 ]%
Cross Entropy Loss: [0.86331654 0.19340706 0.08035859 0.03712239 0.04313882]
Correct: [71.10102992 94.94850417 98.09955861 99.21530162 98.83521334]%
Mean abs weights tensor(0.0239, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 45.80%
From layer 1:
Accuracy: 67.89%
From layer 2:
Accuracy: 74.78%
From layer 3:
Accuracy: 75.18%
From layer 4:
Accuracy: 77.08%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 71.38%
From layer 1:
Accuracy: 95.02%
From layer 2:
Accuracy: 99.15%
From layer 3:
Accuracy: 99.68%
From layer 4:
Accuracy: 96.17%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.2056136 3.3595018 2.5578706 1.8632474 1.7972437]
Correct: [16.00049044 27.78322707 40.5590976  53.23688082 55.21088769]%
Cross Entropy Loss: [2.5309207 1.482339  1.038259  0.8312534 0.7516103]
Correct: [29.27905836 51.98626778 65.21579205 73.01373222 76.17704757]%
Cross Entropy Loss: [2.2038283  1.1757452  0.7548449  0.63051116 0.5463139 ]
Correct: [34.64933791 60.97351643 74.35017165 79.03384012 81.65767533]%
Cross Entropy Loss: [1.9805763  0.9832883  0.6415817  0.55061483 0.4850844 ]
Correct: [39.2594409  66.98136341 78.39627268 82.13585091 84.37959784]%
Cross Entropy Loss: [1.8289003  0.85267454 0.55737317 0.43624127 0.38633227]
Correct: [43.08484551 71.44433546 81.57184895 85.20107896 87.45708681]%
Cross Entropy Loss: [1.7603719  0.7941842  0.47148356 0.38691586 0.34094897]
Correct: [45.16920059 73.46738597 84.12211869 87.07699853 88.47474252]%
Cross Entropy Loss: [1.6479888  0.6881737  0.4219232  0.33900353 0.29466173]
Correct: [47.6949485  76.76557136 85.89995096 89.02648357 90.53457577]%
Cross Entropy Loss: [1.5874764  0.6404749  0.3832031  0.3667489  0.31605873]
Correct: [49.301128   78.22461991 87.22412948 87.95978421 89.94605199]%
Cross Entropy Loss: [1.5148996  0.5874625  0.31833452 0.2530897  0.22447881]
Correct: [50.94409024 80.05149583 89.36978911 91.76066699 92.82736636]%
Cross Entropy Loss: [1.3929242  0.517211   0.2957919  0.23422898 0.19924119]
Correct: [53.911231   82.78567925 90.24031388 92.32466896 93.56302109]%
Cross Entropy Loss: [1.3655514  0.46343637 0.26418233 0.18656996 0.17288554]
Correct: [55.33349681 84.29377146 91.90779794 94.18832761 94.60519863]%
Cross Entropy Loss: [1.3460088  0.44891194 0.2609801  0.18687023 0.16876407]
Correct: [55.4193232  85.01716528 91.20892594 94.01667484 94.71554684]%
Cross Entropy Loss: [1.2780318  0.42506018 0.23402876 0.16550633 0.1458253 ]
Correct: [56.96419814 86.02256008 92.72927906 94.96076508 95.61059343]%
Cross Entropy Loss: [1.1999463  0.37642714 0.2073232  0.17451656 0.16576539]
Correct: [59.22020598 87.71456596 93.64884747 94.47032859 94.76459049]%
Cross Entropy Loss: [1.1591609  0.3495398  0.18536784 0.13290927 0.11321939]
Correct: [60.25012261 88.94065718 94.10250123 95.92937715 96.46885728]%
Cross Entropy Loss: [1.1474069  0.345629   0.19094934 0.1126137  0.09575427]
Correct: [61.26777832 88.80578715 94.06571849 96.75085826 97.08190289]%
Cross Entropy Loss: [1.1157053  0.32503557 0.16931117 0.14722033 0.10864975]
Correct: [62.59195684 89.54144188 95.03433055 95.85581167 96.70181462]%
Cross Entropy Loss: [1.1233001  0.3216199  0.17507222 0.12341614 0.12036664]
Correct: [62.27317312 89.32074546 94.74006866 96.11329083 96.06424718]%
Cross Entropy Loss: [1.084442   0.27813444 0.15503414 0.1219961  0.07550655]
Correct: [63.24178519 91.24570868 95.26728789 96.10102992 97.75625307]%
Cross Entropy Loss: [1.0297537  0.27438426 0.11474238 0.07875086 0.06496222]
Correct: [64.82344286 91.58901422 96.81216282 97.96468857 98.13634134]%
Cross Entropy Loss: [1.0179255  0.26768327 0.12178159 0.10539071 0.09340606]
Correct: [65.43648847 91.31927415 96.53016184 96.84894556 96.89798921]%
Cross Entropy Loss: [0.9677667  0.23042907 0.09970601 0.08376812 0.08283429]
Correct: [67.05492889 93.06032369 97.68268759 97.53555665 97.52329573]%
Cross Entropy Loss: [0.93068594 0.22338991 0.08591006 0.06659058 0.03558221]
Correct: [67.83962727 93.42815105 98.14860226 98.30799411 99.38695439]%
Cross Entropy Loss: [0.92695904 0.21457809 0.08566278 0.1193421  0.04286044]
Correct: [68.64884747 93.56302109 97.85434036 96.59146641 98.85973516]%
Cross Entropy Loss: [0.9117389  0.20471463 0.07515663 0.06993123 0.03466035]
Correct: [68.62432565 94.1760667  98.45512506 98.02599313 99.11721432]%
Cross Entropy Loss: [0.91249263 0.23524472 0.1092992  0.09192672 0.07934052]
Correct: [67.81510544 92.58214811 97.06964198 97.26581658 97.53555665]%
Cross Entropy Loss: [0.9042726  0.1978319  0.09377304 0.07481949 0.11384055]
Correct: [69.10250123 94.10250123 97.43746935 97.70720942 97.02059833]%
Cross Entropy Loss: [0.8498304  0.18166998 0.08278241 0.03884922 0.07832021]
Correct: [70.45120157 94.43354586 97.90338401 99.17851888 97.65816577]%
Cross Entropy Loss: [0.8266773  0.16396347 0.08591412 0.03990128 0.10072411]
Correct: [71.6895537  95.50024522 97.57233938 99.16625797 96.87346739]%
Cross Entropy Loss: [0.8071401  0.15629496 0.06097132 0.06256734 0.06781789]
Correct: [72.24129475 95.6841589  98.63903874 98.17312408 97.75625307]%
Mean abs weights tensor(0.0242, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 45.05%
From layer 1:
Accuracy: 69.08%
From layer 2:
Accuracy: 71.91%
From layer 3:
Accuracy: 73.19%
From layer 4:
Accuracy: 74.87%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 72.55%
From layer 1:
Accuracy: 95.29%
From layer 2:
Accuracy: 97.66%
From layer 3:
Accuracy: 96.78%
From layer 4:
Accuracy: 98.93%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.324941  3.5641503 2.167204  1.99385   1.8944976]
Correct: [15.33840118 26.79009318 42.52084355 49.20304071 55.49288867]%
Cross Entropy Loss: [2.4954576  1.558196   1.0214167  0.82206196 0.73318094]
Correct: [29.76949485 50.1961746  65.43648847 72.91564492 75.9073075 ]%
Cross Entropy Loss: [2.1982431  1.2117525  0.7647834  0.61927146 0.54909784]
Correct: [35.71603727 59.58803335 74.46051986 79.54879843 81.8293281 ]%
Cross Entropy Loss: [1.9816691  1.0023011  0.6322607  0.50368994 0.45409152]
Correct: [39.72535557 66.72388426 78.65375184 83.1289848  85.18881805]%
Cross Entropy Loss: [1.8514938  0.8848068  0.5756617  0.50811195 0.42373785]
Correct: [43.56302109 70.76998529 80.83619421 83.44776851 86.48847474]%
Cross Entropy Loss: [1.6841191  0.7635488  0.47471923 0.40711936 0.3217422 ]
Correct: [46.17459539 74.77930358 83.74203041 86.94212849 89.39431094]%
Cross Entropy Loss: [1.6386919  0.70924807 0.41891712 0.3423481  0.30363506]
Correct: [48.41834232 76.54487494 86.67238843 88.40117705 90.25257479]%
Cross Entropy Loss: [1.5386631  0.61993647 0.3654886  0.31153786 0.30298343]
Correct: [50.45365375 79.67140755 87.92300147 89.79892104 90.362923  ]%
Cross Entropy Loss: [1.4742762  0.57251674 0.32781553 0.24483165 0.23077434]
Correct: [52.09661599 80.77488965 89.38205002 92.11623345 92.43501717]%
Cross Entropy Loss: [1.4021066  0.5318807  0.30169746 0.2463513  0.20541684]
Correct: [54.34036292 82.44237371 90.53457577 92.09171162 93.52623835]%
Cross Entropy Loss: [1.3609004  0.4975689  0.27264592 0.22916225 0.20252237]
Correct: [55.8239333  83.86463953 91.41736145 92.63119176 93.80823933]%
Cross Entropy Loss: [1.2960101  0.46684906 0.23952818 0.21564874 0.16762549]
Correct: [56.91515449 84.41638058 92.65571359 93.23197646 95.05885238]%
Cross Entropy Loss: [1.2774516  0.43955156 0.2160222  0.1738402  0.13685773]
Correct: [57.77341834 85.43403629 93.45267288 94.12702305 95.52476704]%
Cross Entropy Loss: [1.2328824  0.3953155  0.21106435 0.15347378 0.12532842]
Correct: [58.72976949 87.0524767  93.79597842 95.19372241 95.99068171]%
Cross Entropy Loss: [1.186512   0.37037873 0.18252431 0.12997529 0.09352573]
Correct: [59.99264345 88.30308975 94.51937224 96.13781265 97.51103482]%
Cross Entropy Loss: [1.1830083  0.3465231  0.1814412  0.16480218 0.12391769]
Correct: [60.34820991 88.83030897 94.66650319 95.04659147 96.05198627]%
Cross Entropy Loss: [1.1120533  0.33547872 0.17535713 0.11841699 0.15876324]
Correct: [62.71456596 89.10004904 94.72780775 96.35850907 94.86267778]%
Cross Entropy Loss: [1.0661964  0.32379746 0.14929199 0.11706953 0.1185933 ]
Correct: [63.78126533 89.2594409  95.50024522 96.35850907 96.23589995]%
Cross Entropy Loss: [1.0448747  0.28769663 0.14577052 0.11159009 0.07654792]
Correct: [64.72535557 91.061795   95.99068171 96.64051005 97.73173124]%
Cross Entropy Loss: [1.0634643  0.27672145 0.11562611 0.10590365 0.06841727]
Correct: [64.97057381 91.24570868 96.92251103 96.59146641 98.06277587]%
Cross Entropy Loss: [1.0070212  0.26485303 0.11620875 0.08465157 0.09021208]
Correct: [65.42422756 91.89553703 96.97155468 97.37616479 97.20451202]%
Cross Entropy Loss: [0.98477656 0.28237838 0.116499   0.10662637 0.09386032]
Correct: [66.38057872 90.93918588 97.02059833 96.71407553 96.91025012]%
Cross Entropy Loss: [0.9817468  0.25369784 0.09833461 0.09275801 0.08061231]
Correct: [66.71162334 92.25110348 97.62138303 97.05738107 97.25355566]%
Cross Entropy Loss: [0.91994953 0.22278142 0.07945165 0.0501484  0.03278602]
Correct: [68.33006376 93.45267288 98.44286415 98.81069152 99.20304071]%
Cross Entropy Loss: [0.94546384 0.25147757 0.09238899 0.0729221  0.06501099]
Correct: [67.59440902 92.10397254 97.63364394 97.9892104  98.44286415]%
Cross Entropy Loss: [0.8817579  0.1967987  0.07712112 0.05854746 0.03715175]
Correct: [69.49485042 94.11476214 98.24668955 98.47964689 99.11721432]%
Cross Entropy Loss: [0.91573596 0.1897944  0.07192733 0.09692039 0.02604249]
Correct: [69.0779794  94.49485042 98.23442864 97.00833742 99.46051986]%
Cross Entropy Loss: [0.86333513 0.18468462 0.06465902 0.0467017  0.04512371]
Correct: [70.46346248 94.70328592 98.62677783 98.90877881 98.72486513]%
Cross Entropy Loss: [0.8436995  0.1664356  0.06542653 0.03883252 0.03336268]
Correct: [70.36537518 95.36537518 98.5899951  99.16625797 99.23982344]%
Cross Entropy Loss: [0.85385126 0.17009127 0.06540503 0.02821399 0.01392475]
Correct: [70.38989701 95.10789603 98.68808239 99.54634625 99.92643453]%
Mean abs weights tensor(0.0232, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 45.41%
From layer 1:
Accuracy: 66.56%
From layer 2:
Accuracy: 74.16%
From layer 3:
Accuracy: 77.43%
From layer 4:
Accuracy: 78.45%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 72.79%
From layer 1:
Accuracy: 96.35%
From layer 2:
Accuracy: 98.71%
From layer 3:
Accuracy: 99.73%
From layer 4:
Accuracy: 99.99%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [3.988924  3.1324291 2.6133893 2.0082664 1.7002865]
Correct: [17.49632173 27.42766062 38.94065718 51.05443845 56.13045611]%
Cross Entropy Loss: [2.4034095 1.4766755 1.042177  0.813486  0.7545184]
Correct: [30.14958313 52.24374693 64.18587543 72.58460029 75.07356547]%
Cross Entropy Loss: [2.1636655  1.1886683  0.8245683  0.63469    0.56025475]
Correct: [35.87542913 60.18881805 72.25355566 78.35948995 81.63315351]%
Cross Entropy Loss: [1.933696   0.9693465  0.64669544 0.5180526  0.455652  ]
Correct: [40.04413928 67.49632173 77.66061795 82.77341834 84.84551251]%
Cross Entropy Loss: [1.8000529 0.8509969 0.5312405 0.418528  0.3724982]
Correct: [43.73467386 71.87346739 81.62089259 85.8386464  87.54291319]%
Cross Entropy Loss: [1.7256318  0.7600211  0.50571114 0.39436883 0.3519412 ]
Correct: [46.65277097 75.08582639 83.19028936 87.10152035 88.94065718]%
Cross Entropy Loss: [1.6216217  0.69503504 0.43294814 0.36777028 0.33592716]
Correct: [48.24668955 76.42226582 85.10299166 87.7513487  89.02648357]%
Cross Entropy Loss: [1.5195189  0.6152837  0.36658275 0.3207121  0.26731536]
Correct: [50.77243747 79.79401667 87.89847965 89.73761648 91.38057872]%
Cross Entropy Loss: [1.4328399  0.54990774 0.32905525 0.2574844  0.23479114]
Correct: [52.7096616  82.01324179 89.22265817 91.63805787 92.80284453]%
Cross Entropy Loss: [1.4244062  0.5314243  0.311795   0.22268677 0.1960473 ]
Correct: [54.10740559 82.61402648 90.17900932 93.02354095 93.83276116]%
Cross Entropy Loss: [1.3812139  0.4833222  0.2883134  0.2266049  0.18003951]
Correct: [54.99019127 84.90681707 90.64492398 92.54536538 94.0779794 ]%
Cross Entropy Loss: [1.2676761  0.43009827 0.24442723 0.18956302 0.14598212]
Correct: [57.54046101 85.71603727 92.22658166 93.61206474 95.2795488 ]%
Cross Entropy Loss: [1.2385132  0.43745086 0.23291476 0.18338291 0.14943925]
Correct: [58.39872487 85.4462972  92.36145169 94.26189308 94.86267778]%
Cross Entropy Loss: [1.2558703  0.40281573 0.21592072 0.15124676 0.12518702]
Correct: [58.49681216 86.62334478 93.31780284 95.2795488  96.11329083]%
Cross Entropy Loss: [1.1857114  0.36193734 0.18302616 0.13203864 0.11002421]
Correct: [60.07846984 88.2540461  94.5806768  96.088769   96.40755272]%
Cross Entropy Loss: [1.1280024  0.33055058 0.17601904 0.15089615 0.10825602]
Correct: [61.62334478 89.54144188 95.03433055 95.32859245 96.53016184]%
Cross Entropy Loss: [1.0969436  0.33614233 0.16785593 0.11069207 0.08921219]
Correct: [63.07013242 89.27170181 94.99754782 96.60372732 97.33938205]%
Cross Entropy Loss: [1.0771295  0.2995079  0.14733732 0.09099505 0.0805084 ]
Correct: [63.92839627 90.59588033 95.62285434 97.42520844 97.486513  ]%
Cross Entropy Loss: [1.0687064  0.27161282 0.12266469 0.10765455 0.06207873]
Correct: [63.97743992 91.79744973 96.77538009 96.72633644 98.28347229]%
Cross Entropy Loss: [1.0154715  0.28039503 0.11916143 0.09391033 0.05867408]
Correct: [65.69396763 91.24570868 96.77538009 97.16772928 98.47964689]%
Cross Entropy Loss: [1.0040084  0.2532912  0.1227235  0.07905608 0.08915396]
Correct: [65.43648847 92.09171162 96.54242276 97.7930358  97.20451202]%
Cross Entropy Loss: [0.9828105  0.21825738 0.10432242 0.08528934 0.0611566 ]
Correct: [66.84649338 93.83276116 97.40068661 97.30259931 98.23442864]%
Cross Entropy Loss: [0.9722741  0.22587538 0.1088616  0.07630668 0.06865072]
Correct: [66.4664051  93.06032369 97.06964198 98.02599313 98.11181952]%
Cross Entropy Loss: [0.9029567  0.20196205 0.09453553 0.08083413 0.08253992]
Correct: [69.47032859 94.13928396 97.67042668 97.49877391 97.53555665]%
Cross Entropy Loss: [0.93391335 0.19596069 0.08437563 0.04748533 0.07105881]
Correct: [68.35458558 94.71554684 98.07503678 98.87199608 97.84207945]%
Cross Entropy Loss: [0.8545353  0.17513451 0.07545985 0.05404242 0.05624818]
Correct: [70.23050515 95.3776361  98.30799411 98.55321236 98.3938205 ]%
Cross Entropy Loss: [0.8437743  0.17947221 0.0813463  0.08740982 0.05847281]
Correct: [71.05198627 94.96076508 98.01373222 97.05738107 98.1976459 ]%
Cross Entropy Loss: [0.8890023  0.19300587 0.07749287 0.06227902 0.04516035]
Correct: [69.88719961 94.38450221 98.33251594 98.33251594 98.62677783]%
Cross Entropy Loss: [0.83754176 0.16662489 0.06607958 0.05965359 0.0796144 ]
Correct: [71.40755272 95.30407062 98.4919078  98.2957332  97.35164296]%
Cross Entropy Loss: [0.8502513  0.15231785 0.05713814 0.04530533 0.0444615 ]
Correct: [70.74546346 96.03972536 98.88425699 98.8965179  98.61451692]%
Mean abs weights tensor(0.0241, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 44.74%
From layer 1:
Accuracy: 68.51%
From layer 2:
Accuracy: 75.53%
From layer 3:
Accuracy: 77.21%
From layer 4:
Accuracy: 79.20%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 73.80%
From layer 1:
Accuracy: 95.97%
From layer 2:
Accuracy: 99.03%
From layer 3:
Accuracy: 99.88%
From layer 4:
Accuracy: 99.80%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.6227293 3.3122234 2.4990158 2.032588  1.6780014]
Correct: [15.03187837 27.13339872 41.14762138 50.15939186 56.9274154 ]%
Cross Entropy Loss: [2.5443745  1.4194658  0.99239236 0.8217521  0.72340566]
Correct: [29.19323198 52.74644434 66.68710152 73.11181952 76.87591957]%
Cross Entropy Loss: [2.1909614 1.141158  0.7558997 0.596321  0.5258548]
Correct: [34.9435998  61.73369299 74.64443355 80.46836685 83.22707209]%
Cross Entropy Loss: [1.998334   0.9508227  0.6387612  0.54015654 0.4599532 ]
Correct: [39.51692006 68.06032369 78.45757724 83.07994115 85.21333987]%
Cross Entropy Loss: [1.8620007  0.8320849  0.5388197  0.41595495 0.3594737 ]
Correct: [42.3737126  71.62824914 81.62089259 85.94899461 88.40117705]%
Cross Entropy Loss: [1.7234403  0.73059595 0.4366806  0.36158878 0.31049895]
Correct: [46.07650809 75.51495831 85.50760177 88.08239333 89.59048553]%
Cross Entropy Loss: [1.6702443  0.66301334 0.42448047 0.34175032 0.35053933]
Correct: [48.11181952 77.69740069 86.51299657 88.94065718 89.38205002]%
Cross Entropy Loss: [1.564496   0.64910513 0.39027712 0.32333708 0.32050017]
Correct: [49.3992153  78.33496812 87.23639039 89.71309465 90.33840118]%
Cross Entropy Loss: [1.4760293  0.5700429  0.33421382 0.2553209  0.22513282]
Correct: [51.54487494 81.05689063 88.8548308  91.55223149 92.39823443]%
Cross Entropy Loss: [1.4090142  0.5001672  0.296235   0.21848905 0.19267294]
Correct: [54.00931829 83.22707209 90.66944581 93.1706719  93.95537028]%
Cross Entropy Loss: [1.3808565  0.46514562 0.26722318 0.20727138 0.16529001]
Correct: [55.2231486  85.15203531 91.42962236 93.23197646 94.62972045]%
Cross Entropy Loss: [1.300876   0.43479815 0.24435584 0.19516654 0.18536007]
Correct: [56.63315351 85.37273173 92.38597352 93.83276116 93.9798921 ]%
Cross Entropy Loss: [1.2490568  0.4021349  0.2218795  0.16986145 0.13077828]
Correct: [57.96959294 86.8440412  93.12162825 94.70328592 95.95389897]%
Cross Entropy Loss: [1.194781   0.38148913 0.19828737 0.15434018 0.11747424]
Correct: [60.20107896 87.40804316 94.06571849 95.41441883 96.4811182 ]%
Cross Entropy Loss: [1.1902742  0.3513643  0.18725683 0.13932124 0.12292454]
Correct: [60.02942619 88.78126533 94.37224129 95.56154978 95.95389897]%
Cross Entropy Loss: [1.1520844  0.34087902 0.17548335 0.13175164 0.11095686]
Correct: [61.61108386 88.76900441 94.97302599 95.95389897 96.35850907]%
Cross Entropy Loss: [1.1250765  0.3235074  0.16791598 0.13675725 0.11150184]
Correct: [62.71456596 89.13683178 95.14467876 95.6841589  96.50564002]%
Cross Entropy Loss: [1.0755826  0.29459032 0.1392301  0.11479837 0.08126719]
Correct: [63.67091712 90.60814125 96.35850907 96.32172634 97.486513  ]%
Cross Entropy Loss: [1.0744154  0.28409076 0.13477576 0.12542655 0.10026763]
Correct: [63.56056891 91.20892594 96.21137813 96.02746444 96.8857283 ]%
Cross Entropy Loss: [1.0377376  0.2779014  0.12345953 0.08118957 0.06849066]
Correct: [64.78666013 91.22118686 96.67729279 97.76851398 97.94016675]%
Cross Entropy Loss: [0.99990654 0.24207073 0.10484799 0.07835941 0.07164519]
Correct: [66.31927415 92.74153997 97.22903384 97.64590486 97.81755763]%
Cross Entropy Loss: [0.9796146  0.23464899 0.10380007 0.07130498 0.06221569]
Correct: [66.35605689 92.70475723 97.30259931 98.11181952 98.09955861]%
Cross Entropy Loss: [0.9626476  0.20477287 0.08840603 0.06318509 0.04398413]
Correct: [67.30014713 94.24963217 98.05051496 98.44286415 98.92103973]%
Cross Entropy Loss: [0.9072317  0.2276586  0.10449994 0.08032357 0.0927543 ]
Correct: [69.00441393 93.08484551 97.25355566 97.6949485  97.20451202]%
Cross Entropy Loss: [0.93354005 0.1991149  0.08352045 0.05562714 0.03213453]
Correct: [67.99901913 93.91858754 98.06277587 98.56547327 99.35017165]%
Cross Entropy Loss: [0.9068943  0.19786595 0.09104298 0.06897325 0.04347578]
Correct: [69.01667484 94.12702305 97.65816577 98.01373222 98.85973516]%
Cross Entropy Loss: [0.846317   0.18319336 0.07764319 0.04932541 0.03260033]
Correct: [70.86807258 94.74006866 98.20990682 98.74938695 99.28886709]%
Cross Entropy Loss: [0.85570025 0.17203462 0.0563289  0.07962129 0.03103863]
Correct: [70.63511525 95.2795488  99.06817067 97.91564492 99.32564983]%
Cross Entropy Loss: [0.8363541  0.19039999 0.07366792 0.12694584 0.07575922]
Correct: [71.49337911 94.12702305 98.02599313 96.21137813 97.67042668]%
Cross Entropy Loss: [0.8169729  0.14992146 0.05272785 0.03007926 0.15538152]
Correct: [72.06964198 96.16233448 99.05590976 99.48504169 95.19372241]%
Mean abs weights tensor(0.0242, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 44.83%
From layer 1:
Accuracy: 67.31%
From layer 2:
Accuracy: 72.17%
From layer 3:
Accuracy: 76.72%
From layer 4:
Accuracy: 75.62%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 74.77%
From layer 1:
Accuracy: 97.18%
From layer 2:
Accuracy: 99.47%
From layer 3:
Accuracy: 99.57%
From layer 4:
Accuracy: 95.62%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.0144196 3.4878871 2.3410316 1.8425995 2.0374084]
Correct: [16.29475233 29.29131927 40.09318293 53.5188818  53.99705738]%
Cross Entropy Loss: [2.3983402 1.4973923 1.0394697 0.8250453 0.7729796]
Correct: [30.65228053 51.18930848 64.9583129  73.01373222 74.60765081]%
Cross Entropy Loss: [2.1310823 1.2015562 0.7879315 0.6552061 0.5681548]
Correct: [36.18195194 59.40411967 73.05051496 78.83766552 81.33889161]%
Cross Entropy Loss: [1.9213041  0.9902826  0.635482   0.53706896 0.4645162 ]
Correct: [40.5590976  66.39283963 78.06522805 82.49141736 84.7351643 ]%
Cross Entropy Loss: [1.8263688  0.85059613 0.55404454 0.4179879  0.3957129 ]
Correct: [43.62432565 71.24816086 81.46150074 86.19421285 86.6478666 ]%
Cross Entropy Loss: [1.6964254  0.79455066 0.45513493 0.39159617 0.34748095]
Correct: [46.23589995 73.45512506 84.79646886 87.19960765 88.41343796]%
Cross Entropy Loss: [1.5983654  0.6697549  0.38751093 0.32011092 0.29438794]
Correct: [48.34477685 77.48896518 87.0524767  89.66405101 90.10544385]%
Cross Entropy Loss: [1.5216894  0.6108541  0.38179174 0.3051002  0.28664818]
Correct: [51.09122119 79.69592938 87.73908779 89.9583129  90.44874939]%
Cross Entropy Loss: [1.442836   0.5885045  0.34718186 0.28229165 0.252555  ]
Correct: [52.3050515  80.71358509 88.92839627 90.89014223 91.98136341]%
Cross Entropy Loss: [1.3948885  0.5434982  0.28680113 0.239869   0.20033774]
Correct: [54.61010299 81.98871996 90.73075037 92.3737126  93.44041197]%
Cross Entropy Loss: [1.3419678  0.49429786 0.26890284 0.20233306 0.18334925]
Correct: [55.76262874 83.58263855 91.29475233 93.61206474 94.06571849]%
Cross Entropy Loss: [1.2863932  0.45195988 0.25378847 0.18740258 0.13945231]
Correct: [57.27072094 85.43403629 91.98136341 94.13928396 95.81902894]%
Cross Entropy Loss: [1.2747281  0.4539103  0.20959708 0.18265006 0.14377558]
Correct: [58.15350662 85.33594899 93.42815105 94.15154487 95.42667974]%
Cross Entropy Loss: [1.2045016  0.3966868  0.21928594 0.17681737 0.14676303]
Correct: [59.47768514 87.45708681 92.96223639 94.42128494 95.10789603]%
Cross Entropy Loss: [1.2008381  0.37871504 0.19987358 0.16713439 0.10906809]
Correct: [59.40411967 87.80039235 93.63658656 94.4825895  96.46885728]%
Cross Entropy Loss: [1.1866623  0.33645394 0.16413422 0.14532022 0.09666897]
Correct: [60.55664541 89.48013732 95.2795488  95.62285434 97.10642472]%
Cross Entropy Loss: [1.1345832  0.33049986 0.15712306 0.12866771 0.11735864]
Correct: [61.66012751 89.38205002 95.35311427 95.7822462  96.26042178]%
Cross Entropy Loss: [1.0831337  0.30687505 0.14024039 0.11480632 0.0945673 ]
Correct: [63.40117705 90.58361942 96.05198627 96.41981363 96.84894556]%
Cross Entropy Loss: [1.0314867  0.2650021  0.12519914 0.09160795 0.09225935]
Correct: [65.0564002  92.00588524 96.72633644 97.32712114 96.97155468]%
Cross Entropy Loss: [1.0191456  0.261858   0.12771265 0.09649575 0.09357802]
Correct: [64.90926925 92.12849436 96.3830309  96.9838156  96.71407553]%
Cross Entropy Loss: [1.0035199  0.25162822 0.10839103 0.08184652 0.08089903]
Correct: [65.21579205 92.15301618 97.32712114 97.53555665 97.49877391]%
Cross Entropy Loss: [0.96976113 0.25738508 0.10238311 0.08271568 0.05751785]
Correct: [66.99362433 92.1652771  97.58460029 97.43746935 98.25895047]%
Cross Entropy Loss: [0.93922865 0.22113715 0.08706893 0.07554234 0.06556441]
Correct: [67.34919078 93.39136832 98.02599313 97.71947033 98.13634134]%
Cross Entropy Loss: [0.92626643 0.2164144  0.08296338 0.05986106 0.06749475]
Correct: [68.33006376 93.45267288 98.09955861 98.3938205  98.35703776]%
Cross Entropy Loss: [0.9004449  0.22270176 0.09091657 0.09831158 0.07621037]
Correct: [69.4825895  93.28102011 97.86660128 96.92251103 97.40068661]%
Cross Entropy Loss: [0.9043112  0.21227756 0.10489137 0.07800061 0.11881156]
Correct: [68.78371751 93.40362923 96.93477195 97.64590486 96.22363904]%
Cross Entropy Loss: [0.8712321  0.20500828 0.08569799 0.04083661 0.07219958]
Correct: [70.2795488  93.82050025 97.81755763 99.16625797 97.65816577]%
Cross Entropy Loss: [0.8553014  0.17278731 0.07216253 0.03708883 0.0357147 ]
Correct: [70.19372241 95.10789603 98.17312408 99.25208436 99.14173615]%
Cross Entropy Loss: [0.8474399  0.16874377 0.06174877 0.04865002 0.03654258]
Correct: [70.16920059 95.29180971 98.74938695 98.84747425 99.01912702]%
Cross Entropy Loss: [0.8345765  0.17679413 0.0635413  0.08411106 0.09466119]
Correct: [70.91711623 94.81363413 98.61451692 97.57233938 97.96468857]%
Mean abs weights tensor(0.0242, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 46.20%
From layer 1:
Accuracy: 65.77%
From layer 2:
Accuracy: 76.02%
From layer 3:
Accuracy: 73.76%
From layer 4:
Accuracy: 75.09%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 75.12%
From layer 1:
Accuracy: 95.03%
From layer 2:
Accuracy: 99.77%
From layer 3:
Accuracy: 93.39%
From layer 4:
Accuracy: 97.74%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.137089  2.83905   2.5151694 2.031978  1.830148 ]
Correct: [17.10397254 30.3212359  39.41883276 51.29965669 54.68366846]%
Cross Entropy Loss: [2.486774   1.4545557  1.0585564  0.82113963 0.7482235 ]
Correct: [29.97793036 52.31731241 64.57822462 72.57233938 75.76017656]%
Cross Entropy Loss: [2.1698017  1.1517596  0.8305999  0.63961565 0.5538125 ]
Correct: [36.45169201 61.80725846 72.13094654 79.34036292 81.75576263]%
Cross Entropy Loss: [1.9488498  0.99029315 0.6562249  0.49920967 0.44063875]
Correct: [40.25257479 67.61893085 78.05296714 83.76655223 85.78960275]%
Cross Entropy Loss: [1.840665   0.8263306  0.5413463  0.41866308 0.36304823]
Correct: [43.53849926 72.94016675 82.29524277 86.09612555 88.1559588 ]%
Cross Entropy Loss: [1.729859   0.748977   0.4979274  0.38244766 0.36538473]
Correct: [45.72094164 75.25747916 83.60716037 87.39578225 88.49926435]%
Cross Entropy Loss: [1.6232448  0.6763797  0.42611527 0.3156386  0.29797688]
Correct: [48.28347229 77.90583619 86.02256008 89.73761648 90.65718489]%
Cross Entropy Loss: [1.5509827  0.62063855 0.3602563  0.28464812 0.25161597]
Correct: [50.62530652 79.69592938 88.38891614 90.59588033 92.07945071]%
Cross Entropy Loss: [1.4925711  0.5677135  0.33827853 0.25637263 0.22750531]
Correct: [51.58165768 81.3266307  88.81804806 92.01814615 92.53310446]%
Cross Entropy Loss: [1.400738   0.5297122  0.30902562 0.22143154 0.22178225]
Correct: [53.88670917 82.3320255  90.19127023 92.72927906 92.49632173]%
Cross Entropy Loss: [1.3744516  0.4757382  0.2693088  0.20705101 0.20768033]
Correct: [55.1250613  84.14664051 91.28249142 93.50171653 93.42815105]%
Cross Entropy Loss: [1.3562679  0.46579775 0.27394965 0.1867522  0.15391019]
Correct: [56.04462972 84.99264345 91.18440412 94.1760667  95.15693968]%
Cross Entropy Loss: [1.2604469  0.41621557 0.22652529 0.18300518 0.13387336]
Correct: [57.78567925 86.23099559 92.6679745  93.88180481 95.79450711]%
Cross Entropy Loss: [1.2221175  0.40264988 0.21035613 0.16891769 0.1641858 ]
Correct: [59.06081412 87.04021579 93.53849926 94.85041687 94.69102501]%
Cross Entropy Loss: [1.1929152  0.36644056 0.20007657 0.12945664 0.11358865]
Correct: [59.80872977 88.09465424 94.0779794  96.11329083 96.64051005]%
Cross Entropy Loss: [1.2007483  0.35418335 0.17994905 0.12616847 0.09320418]
Correct: [60.37273173 88.79352624 94.45806768 96.17459539 97.09416381]%
Cross Entropy Loss: [1.1082902  0.3305809  0.16217703 0.10457757 0.08944942]
Correct: [62.18734674 89.21039725 95.50024522 97.10642472 97.3884257 ]%
Cross Entropy Loss: [1.0830909  0.30762458 0.14850996 0.10405774 0.07111124]
Correct: [63.08239333 90.04413928 95.48798431 97.03285924 98.06277587]%
Cross Entropy Loss: [1.0538728  0.2823848  0.14326942 0.11936232 0.10511969]
Correct: [64.44335459 91.40510054 95.6841589  96.22363904 96.66503188]%
Cross Entropy Loss: [1.0664315  0.2609587  0.14948817 0.10264071 0.10978639]
Correct: [63.79352624 91.9691025  95.40215792 96.75085826 96.19911721]%
Cross Entropy Loss: [1.0108529  0.23458037 0.11378931 0.08178481 0.11524967]
Correct: [65.12996567 93.4771947  96.8857283  97.75625307 96.73859735]%
Cross Entropy Loss: [0.9756915  0.22943681 0.10149132 0.06685366 0.07484752]
Correct: [66.09857773 92.96223639 97.59686121 98.14860226 97.90338401]%
Cross Entropy Loss: [0.9648824  0.23275249 0.11866336 0.08369678 0.05058756]
Correct: [67.12849436 92.85188818 96.8857283  97.53555665 98.67582148]%
Cross Entropy Loss: [0.9820328  0.23353538 0.09494952 0.07922387 0.07552725]
Correct: [66.45414419 92.63119176 97.53555665 97.486513   97.76851398]%
Cross Entropy Loss: [0.90933836 0.19914022 0.07428931 0.04977357 0.0506533 ]
Correct: [68.575282   94.35998038 98.43060324 98.87199608 98.54095145]%
Cross Entropy Loss: [0.88936853 0.19305217 0.08173259 0.06428906 0.04125523]
Correct: [69.74006866 94.56841589 98.09955861 98.13634134 98.87199608]%
Cross Entropy Loss: [0.87242335 0.19663645 0.07283563 0.08473666 0.13250236]
Correct: [70.10789603 94.23737126 98.36929868 97.57233938 96.33398725]%
Cross Entropy Loss: [0.84897727 0.16397747 0.05372507 0.04592593 0.04084131]
Correct: [70.85581167 95.41441883 99.12947523 98.92103973 98.88425699]%
Cross Entropy Loss: [0.8634181  0.16451418 0.06619841 0.03509476 0.02000098]
Correct: [70.36537518 95.30407062 98.47964689 99.10495341 99.58312898]%
Cross Entropy Loss: [0.8411176  0.15916032 0.0762966  0.0402774  0.02173898]
Correct: [70.80676802 95.41441883 97.9892104  99.06817067 99.57086807]%
Mean abs weights tensor(0.0236, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 43.99%
From layer 1:
Accuracy: 67.89%
From layer 2:
Accuracy: 71.95%
From layer 3:
Accuracy: 74.78%
From layer 4:
Accuracy: 75.44%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 71.95%
From layer 1:
Accuracy: 96.16%
From layer 2:
Accuracy: 97.90%
From layer 3:
Accuracy: 98.86%
From layer 4:
Accuracy: 98.59%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.039971  2.7851174 2.486452  1.9430459 1.8430986]
Correct: [16.84649338 31.81706719 38.8548308  52.74644434 54.45071113]%
Cross Entropy Loss: [2.463136   1.4294655  1.0284486  0.78284824 0.7431666 ]
Correct: [30.8239333  52.78322707 64.98283472 74.03138794 76.63070132]%
Cross Entropy Loss: [2.1937478  1.1570605  0.79711735 0.6259413  0.5709303 ]
Correct: [35.5443845  61.12064738 73.28347229 79.64688573 81.47376165]%
Cross Entropy Loss: [2.0137358  0.9600193  0.6616912  0.53782725 0.4365628 ]
Correct: [38.36439431 67.17753801 77.91809711 82.18489456 85.28690535]%
Cross Entropy Loss: [1.8426241  0.8427452  0.5743095  0.4306796  0.40272054]
Correct: [42.15301618 71.40755272 80.99558607 85.66699362 86.75821481]%
Cross Entropy Loss: [1.7311393  0.77135646 0.49877268 0.4096734  0.3651322 ]
Correct: [45.59833252 74.46051986 83.39872487 86.89308485 88.20500245]%
Cross Entropy Loss: [1.6484203  0.66545975 0.4068076  0.3251174  0.28365132]
Correct: [48.12408043 77.42766062 86.72143207 89.13683178 90.80431584]%
Cross Entropy Loss: [1.5442686  0.5987709  0.37030056 0.28895703 0.23618457]
Correct: [50.41687102 79.63462482 87.95978421 90.53457577 91.95684159]%
Cross Entropy Loss: [1.4916548  0.53978664 0.3424108  0.25369436 0.21247785]
Correct: [51.54487494 82.46689554 88.70769985 91.89553703 93.13388916]%
Cross Entropy Loss: [1.4133347  0.509476   0.32075867 0.24944818 0.22278088]
Correct: [54.00931829 83.39872487 89.34526729 92.01814615 92.69249632]%
Cross Entropy Loss: [1.3830287  0.5001519  0.2898191  0.23217046 0.195896  ]
Correct: [54.95340853 83.16576753 90.25257479 92.48406081 93.08484551]%
Cross Entropy Loss: [1.2901641  0.43293515 0.24506919 0.18838742 0.14955066]
Correct: [57.25846003 85.65473271 92.28788622 94.24963217 95.31633154]%
Cross Entropy Loss: [1.2372777  0.39047337 0.22973025 0.166224   0.16230658]
Correct: [58.52133399 87.42030407 92.83962727 94.66650319 94.8749387 ]%
Cross Entropy Loss: [1.2153274  0.38333535 0.20772064 0.13805215 0.1135439 ]
Correct: [59.26924963 87.30995586 93.52623835 95.90485532 96.55468367]%
Cross Entropy Loss: [1.1932608  0.35944295 0.17725739 0.12976323 0.10778987]
Correct: [60.16429622 88.22952428 94.86267778 96.02746444 96.61598823]%
Cross Entropy Loss: [1.1951419  0.34799445 0.18141316 0.16180287 0.13921611]
Correct: [60.43403629 89.28396273 94.60519863 94.88719961 95.67189799]%
Cross Entropy Loss: [1.1390563  0.32985952 0.15548478 0.12224641 0.14337325]
Correct: [61.63560569 89.07552722 95.57381069 96.10102992 95.32859245]%
Cross Entropy Loss: [1.0992781  0.29267445 0.14065517 0.11397361 0.09010758]
Correct: [62.64100049 90.75527219 96.2849436  96.59146641 97.32712114]%
Cross Entropy Loss: [1.0419786  0.2911895  0.15330629 0.11333431 0.07719459]
Correct: [64.82344286 91.17214321 95.69641981 96.6895537  97.57233938]%
Cross Entropy Loss: [1.0367918  0.2731346  0.13755295 0.1101648  0.05505923]
Correct: [64.43109367 91.73614517 95.95389897 96.71407553 98.62677783]%
Cross Entropy Loss: [0.9959506  0.24609745 0.11540638 0.0912552  0.07211863]
Correct: [65.54683668 92.39823443 96.9838156  97.33938205 97.82981854]%
Cross Entropy Loss: [0.98092103 0.25717545 0.11006273 0.08322906 0.06834142]
Correct: [66.31927415 91.93231976 97.00833742 97.6949485  97.80529671]%
Cross Entropy Loss: [0.9605376  0.23756683 0.10747038 0.09469084 0.08080539]
Correct: [67.23884257 92.76606179 97.13094654 96.91025012 97.33938205]%
Cross Entropy Loss: [0.92429084 0.23098421 0.10135122 0.09095821 0.06064046]
Correct: [68.41589014 92.74153997 97.31486023 97.21677293 98.14860226]%
Cross Entropy Loss: [0.9193822  0.19620693 0.07601243 0.06427409 0.06440874]
Correct: [68.66110839 94.42128494 98.28347229 98.27121138 98.25895047]%
Cross Entropy Loss: [0.87003124 0.1799619  0.07674473 0.05839045 0.03519527]
Correct: [70.10789603 94.81363413 98.17312408 98.52869053 99.17851888]%
Cross Entropy Loss: [0.88959146 0.20246783 0.09782011 0.0483296  0.02866284]
Correct: [69.49485042 93.82050025 97.52329573 98.72486513 99.37469348]%
Cross Entropy Loss: [0.8571729  0.15877236 0.07271554 0.03658586 0.02733131]
Correct: [70.57381069 95.85581167 98.35703776 99.21530162 99.46051986]%
Cross Entropy Loss: [0.85579056 0.18011901 0.06780642 0.05525568 0.03207569]
Correct: [70.41441883 94.53163315 98.52869053 98.5899951  99.23982344]%
Cross Entropy Loss: [0.8311276  0.15854472 0.05081272 0.10051975 0.02865989]
Correct: [71.11329083 95.31633154 99.23982344 97.09416381 99.38695439]%
Mean abs weights tensor(0.0232, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 45.89%
From layer 1:
Accuracy: 70.05%
From layer 2:
Accuracy: 74.03%
From layer 3:
Accuracy: 75.80%
From layer 4:
Accuracy: 77.25%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 73.75%
From layer 1:
Accuracy: 97.25%
From layer 2:
Accuracy: 99.19%
From layer 3:
Accuracy: 98.47%
From layer 4:
Accuracy: 99.63%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [4.184788  3.1587377 2.7761364 2.1190047 1.5790383]
Correct: [17.58214811 28.02844532 39.1613536  52.24374693 57.99411476]%
Cross Entropy Loss: [2.5372136 1.4718604 0.9933521 0.7935557 0.7099757]
Correct: [29.46297205 52.2069642  66.17214321 74.27660618 76.5080922 ]%
Cross Entropy Loss: [2.1657991  1.1554668  0.773986   0.62683105 0.55019426]
Correct: [35.50760177 61.61108386 73.83521334 79.49975478 82.47915645]%
Cross Entropy Loss: [2.04431    1.0059303  0.65964514 0.53714925 0.4942175 ]
Correct: [39.1613536  66.4664051  77.83227072 82.03776361 84.61255517]%
Cross Entropy Loss: [1.8531626  0.84471947 0.53601384 0.4364727  0.37458396]
Correct: [43.36684649 72.03285924 82.30750368 85.49534085 87.67778323]%
Cross Entropy Loss: [1.7145026  0.74817437 0.46584633 0.37496412 0.3416456 ]
Correct: [45.64737616 74.9019127  84.56351153 87.66552231 88.79352624]%
Cross Entropy Loss: [1.632453   0.67209244 0.4139974  0.31904843 0.284209  ]
Correct: [47.70720942 77.53800883 86.02256008 89.2594409  90.70622854]%
Cross Entropy Loss: [1.5207928  0.6020703  0.3498041  0.26755515 0.24204047]
Correct: [50.83374203 79.90436488 88.47474252 91.44188328 91.95684159]%
Cross Entropy Loss: [1.4819714  0.55549556 0.33333194 0.27373856 0.22522403]
Correct: [51.33643943 81.31436979 88.70769985 90.90240314 92.63119176]%
Cross Entropy Loss: [1.3934038  0.50689507 0.30238512 0.25241745 0.228473  ]
Correct: [53.86218735 83.44776851 90.0564002  92.26336439 92.38597352]%
Cross Entropy Loss: [1.3720877  0.4691646  0.25921592 0.22906601 0.20239407]
Correct: [54.54879843 84.62481609 91.4664051  92.65571359 93.68563021]%
Cross Entropy Loss: [1.2845447  0.43691215 0.24668133 0.21428256 0.2010653 ]
Correct: [57.31976459 85.89995096 92.10397254 92.9744973  93.73467386]%
Cross Entropy Loss: [1.280619   0.405934   0.22534233 0.15904456 0.14694521]
Correct: [58.00637567 86.81951937 92.76606179 95.14467876 95.41441883]%
Cross Entropy Loss: [1.2268513  0.40257344 0.20199575 0.16112684 0.1530347 ]
Correct: [59.35507602 86.8440412  94.16380579 95.21824424 95.16920059]%
Cross Entropy Loss: [1.1595879  0.34152353 0.1825932  0.13547249 0.12401041]
Correct: [60.16429622 89.22265817 94.31093673 95.79450711 95.76998529]%
Cross Entropy Loss: [1.1740674  0.34141546 0.16432494 0.10791507 0.10024589]
Correct: [60.98577734 89.38205002 95.23050515 96.97155468 96.95929377]%
Cross Entropy Loss: [1.1018981  0.33131233 0.16468696 0.11982096 0.10149852]
Correct: [62.49386954 89.49239823 94.88719961 96.2849436  96.77538009]%
Cross Entropy Loss: [1.0698056  0.30210882 0.14659934 0.10793016 0.07807189]
Correct: [64.01422266 90.98822952 95.84355076 96.77538009 97.67042668]%
Cross Entropy Loss: [1.0390097  0.27483198 0.14629385 0.11470152 0.1244206 ]
Correct: [64.30848455 91.52770966 95.54928887 96.59146641 96.11329083]%
Cross Entropy Loss: [1.0666074  0.2754702  0.13488962 0.08425768 0.10074703]
Correct: [63.67091712 91.41736145 96.03972536 97.57233938 96.46885728]%
Cross Entropy Loss: [0.989197   0.26740304 0.12125856 0.0805189  0.05789385]
Correct: [66.19666503 91.27023051 96.50564002 97.7930358  98.34477685]%
Cross Entropy Loss: [0.97220165 0.23401415 0.09364659 0.06646873 0.04612004]
Correct: [66.3683178  92.75380088 97.71947033 98.40608141 98.81069152]%
Cross Entropy Loss: [0.9775605  0.22502445 0.09544985 0.08029222 0.05041957]
Correct: [66.20892594 93.36684649 97.71947033 97.44973026 98.60225601]%
Cross Entropy Loss: [0.9235293  0.2295003  0.11813275 0.13883506 0.11215428]
Correct: [68.50171653 93.23197646 96.94703286 96.00294262 97.36390387]%
Cross Entropy Loss: [0.89851034 0.2234187  0.13292783 0.07578858 0.13245551]
Correct: [69.05345758 92.91319274 95.86807258 97.63364394 95.62285434]%
Cross Entropy Loss: [0.8868577  0.2050404  0.10132988 0.05474962 0.05260646]
Correct: [69.71554684 93.68563021 96.86120647 98.62677783 98.35703776]%
Cross Entropy Loss: [0.8978951  0.17945962 0.06670823 0.03812624 0.02593508]
Correct: [69.23737126 95.09563512 98.67582148 99.27660618 99.42373713]%
Cross Entropy Loss: [0.8495217  0.17978917 0.06059927 0.03778436 0.02211628]
Correct: [70.90485532 94.7768514  98.81069152 99.23982344 99.71799902]%
Cross Entropy Loss: [0.8340609  0.1867944  0.05535805 0.04479715 0.02870846]
Correct: [71.1868563  94.29867582 98.95782246 98.83521334 99.32564983]%
Cross Entropy Loss: [0.809062   0.14747775 0.05688252 0.03789786 0.02827317]
Correct: [71.79990191 96.10102992 98.77390878 99.19077979 99.37469348]%
Mean abs weights tensor(0.0237, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 43.37%
From layer 1:
Accuracy: 66.25%
From layer 2:
Accuracy: 73.54%
From layer 3:
Accuracy: 74.69%
From layer 4:
Accuracy: 75.93%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 71.44%
From layer 1:
Accuracy: 95.94%
From layer 2:
Accuracy: 98.75%
From layer 3:
Accuracy: 99.41%
From layer 4:
Accuracy: 97.98%
  0%|          | 0/30 [00:00<?, ?it/s]
Cross Entropy Loss: [3.8450236 2.893015  2.5818913 1.950869  1.7146541]
Correct: [18.19519372 30.84845513 39.89700834 50.79695929 56.08141246]%
Cross Entropy Loss: [2.4324782 1.4438946 1.0239114 0.7891099 0.7103934]
Correct: [30.39480137 52.84453163 65.24031388 73.41834232 76.6061795 ]%
Cross Entropy Loss: [2.1515372  1.1023425  0.7595511  0.58845913 0.53165674]
Correct: [35.71603727 62.97204512 74.14173615 80.34575772 82.60176557]%
Cross Entropy Loss: [1.9809672 0.9860559 0.690472  0.5189461 0.4655997]
Correct: [39.72535557 67.55762629 77.39087788 83.07994115 84.75968612]%
Cross Entropy Loss: [1.8339075  0.83092946 0.52102846 0.40555686 0.3964398 ]
Correct: [42.11623345 72.30259931 82.81020108 86.67238843 87.11378127]%
Cross Entropy Loss: [1.7236562  0.7560297  0.46359384 0.41708466 0.34153315]
Correct: [45.31633154 75.1961746  84.9435998  86.61108386 89.39431094]%
Cross Entropy Loss: [1.6499207  0.680678   0.42241886 0.35685247 0.3319593 ]
Correct: [47.75625307 77.59931339 85.65473271 88.24178519 89.14909269]%
Cross Entropy Loss: [1.5579765  0.62613094 0.39922643 0.306875   0.27899   ]
Correct: [49.91417361 79.45071113 86.99117214 89.99509564 90.9637077 ]%
Cross Entropy Loss: [1.4878823  0.53445643 0.3487977  0.24159333 0.23848248]
Correct: [51.36096126 82.36880824 88.57282982 92.05492889 92.52084355]%
Cross Entropy Loss: [1.4325871  0.5162798  0.3049619  0.23902264 0.19422716]
Correct: [53.42079451 82.66307013 90.08092202 92.27562531 93.83276116]%
Cross Entropy Loss: [1.3437029  0.45763406 0.25552455 0.19834314 0.15860814]
Correct: [55.75036783 85.29916626 91.93231976 93.63658656 94.83815596]%
Cross Entropy Loss: [1.3352     0.47081318 0.27910143 0.23056087 0.18670449]
Correct: [55.46836685 84.29377146 91.08631682 92.81510544 94.22511035]%
Cross Entropy Loss: [1.2473556  0.40079    0.2130526  0.17290083 0.15900648]
Correct: [57.82246199 86.96665032 93.73467386 94.8749387  95.07111329]%
Cross Entropy Loss: [1.235354   0.39881483 0.20369032 0.1417997  0.12176415]
Correct: [58.59489946 87.28543404 93.93084846 95.34085336 96.37076999]%
Cross Entropy Loss: [1.241568   0.39221078 0.2009701  0.15674618 0.13668656]
Correct: [59.01177048 87.150564   93.91858754 95.14467876 95.67189799]%
Cross Entropy Loss: [1.1979336  0.3370392  0.18356739 0.13470073 0.10851642]
Correct: [59.83325159 89.13683178 94.40902403 95.90485532 96.86120647]%
Cross Entropy Loss: [1.1245631  0.31244665 0.16034669 0.11932953 0.10765743]
Correct: [61.36586562 90.01961746 95.35311427 96.27268269 96.53016184]%
Cross Entropy Loss: [1.0990295  0.31648797 0.14210427 0.12133203 0.102965  ]
Correct: [63.27856793 90.28935753 96.03972536 96.21137813 96.71407553]%
Cross Entropy Loss: [1.0387309  0.29575467 0.12608631 0.10486513 0.07526179]
Correct: [64.68857283 90.33840118 96.61598823 96.787641   97.73173124]%
Cross Entropy Loss: [1.0274373  0.27018297 0.12596369 0.09948713 0.10568168]
Correct: [65.37518391 91.30701324 96.49337911 97.05738107 96.61598823]%
Cross Entropy Loss: [1.0123237  0.2550003  0.11294316 0.06896314 0.0644097 ]
Correct: [65.65718489 92.11623345 97.16772928 98.27121138 97.94016675]%
Cross Entropy Loss: [0.98168045 0.24600674 0.10362274 0.10618742 0.06248716]
Correct: [66.00049044 92.1652771  97.52329573 96.56694458 98.2957332 ]%
Cross Entropy Loss: [0.96057165 0.23113486 0.10995624 0.09623755 0.07254453]
Correct: [66.71162334 93.08484551 97.02059833 96.9838156  97.95242766]%
Cross Entropy Loss: [0.9425977  0.22046018 0.09850635 0.09374876 0.07084412]
Correct: [67.88867092 93.35458558 97.41294752 97.08190289 97.95242766]%
Cross Entropy Loss: [0.9229381  0.22128108 0.09750687 0.07065929 0.08706443]
Correct: [68.46493379 93.28102011 97.46199117 97.84207945 97.17999019]%
Cross Entropy Loss: [0.9139848  0.19916731 0.09013063 0.05390324 0.03303539]
Correct: [68.23197646 93.90632663 97.47425208 98.46738597 99.17851888]%
Cross Entropy Loss: [0.8758341  0.17737143 0.06296918 0.0370774  0.04109732]
Correct: [69.85041687 94.7768514  98.78616969 99.28886709 98.9946052 ]%
Cross Entropy Loss: [0.8290601  0.17428882 0.06098385 0.05355379 0.03314197]
Correct: [71.40755272 95.05885238 98.77390878 98.44286415 99.21530162]%
Cross Entropy Loss: [0.84602153 0.15500756 0.06126471 0.03661681 0.04568071]
Correct: [71.13781265 95.57381069 98.78616969 99.26434527 98.76164787]%
Cross Entropy Loss: [0.84241277 0.15862377 0.05996899 0.07363057 0.02933183]
Correct: [70.76998529 95.94163806 98.93330064 97.73173124 99.32564983]%
Mean abs weights tensor(0.0237, grad_fn=<MeanBackward0>)
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 44.92%
From layer 1:
Accuracy: 69.92%
From layer 2:
Accuracy: 75.35%
From layer 3:
Accuracy: 74.91%
From layer 4:
Accuracy: 77.43%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 71.13%
From layer 1:
Accuracy: 97.35%
From layer 2:
Accuracy: 99.55%
From layer 3:
Accuracy: 98.52%
From layer 4:
Accuracy: 99.88%
Fast Classifier Mean Test Accuracy: tensor([45.0221, 67.9240, 73.9443, 75.3666, 76.6343]), Std: tensor([0.8696, 1.4853, 1.5259, 1.4153, 1.4726])
Fast Classifier Mean Train Accuracy: tensor([72.8678, 96.1550, 98.9186, 98.4282, 98.4343]), Std: tensor([1.4359, 0.8877, 0.6875, 1.9982, 1.5555])
In [ ]:
# run all checkpoints
if not args.augment: 
    last_ckp = 20
    test_accs_ckpt = []
    train_accs_ckpt = []
    epochs = []
    while True:
        print(f'Checkpoint {last_ckp}')
        SNN_ckp = EchoSpike(args.n_inputs, args.n_hidden, beta=args.beta, device=device, recurrency_type=args.recurrency_type, online=args.online).to(device)
        try:
            print(model_name[:-3] + f'_epoch{last_ckp}.pt')
            state_dict = torch.load(model_name[:-3] + f'_epoch{last_ckp}.pt', map_location=device)
            # state_dict = {key.replace('clapp', 'layers'):value for key, value in state_dict.items()}
            # torch.save(state_dict, model_name + f'_epoch{last_ckp}.pt')
            SNN_ckp.load_state_dict(state_dict)
        except:
            if last_ckp > 1500:
                break
            else:
                last_ckp += 20
                continue
        epochs.append(last_ckp)
        last_ckp += 20
        snn_samples, targets = get_samples(SNN_ckp, train_loader, args.n_hidden, device)
        cat = True
        with torch.no_grad():
            out_projs, acc, losses_out = train_out_proj_fast(SNN_ckp, args, 60, 60, snn_samples, train_loader.y, cat=cat, lr=1e-4, weight_decay=1)
        test_accs_ckpt.append(torch.tensor(get_accuracy(SNN_ckp, out_projs, test_loader, device, cat=cat)[0]))
        train_accs_ckpt.append(torch.tensor(get_accuracy(SNN_ckp, out_projs, train_loader, device, cat=cat)[0])) 
    test_accs_ckpt = torch.stack(test_accs_ckpt)
    train_accs_ckpt = torch.stack(train_accs_ckpt)
    # save the results
    torch.save(torch.stack([torch.tensor(tac) for tac in test_accs_ckpt]), model_name[:-3]+'_test_accs_ckpt.pt')
    torch.save(torch.stack([torch.tensor(tac) for tac in train_accs_ckpt]), model_name[:-3]+'_train_accs_ckpt.pt')
In [ ]:
# plot train and test accuracy over time
plt.figure()
for i in range(test_accs_ckpt.shape[-1]):
    if i == 0:
        plt.plot(epochs, 100*test_accs_ckpt[:,i], color=color_list[i], label=f'Directly from inputs')
        plt.plot(epochs, 100*train_accs_ckpt[:,i], color=color_list[i], linestyle='--')
    else:
        plt.plot(epochs, 100*test_accs_ckpt[:,i], color=color_list[i], label=f'Layer {i}')
        plt.plot(epochs, 100*train_accs_ckpt[:,i], color=color_list[i], linestyle='--')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f4d60de9850>
In [ ]:
if not args.augment:
    # if snn_samples already exist, don't recompute
    try:
        snn_samples
    except:
        snn_samples, targets = get_samples(SNN, train_loader, args.n_hidden, device)
    cat = True
    out_projs_closed = train_out_proj_closed_form(args, snn_samples, targets, cat=cat)
    test_acc_closed, _ = get_accuracy(SNN, out_projs_closed, test_loader, device, cat=cat)
    train_acc_closed, _ = get_accuracy(SNN, out_projs_closed, train_loader, device, cat=cat)

    # grouped Bar plot the Accuracies of the different layers both during training and testing
    sns.set_theme(style="whitegrid")
    labels = ['From Inputs Directly', *[f'Until Layer {i+1}' for i in range(len(SNN.layers))]]
    if not cat:
        labels = ['From Inputs Directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
    x = np.arange(len(labels))  # the label locations
    width = 0.35  # the width of the bars
    fig, ax = plt.subplots()
    rects1 = ax.bar(x - width/2, 100*torch.tensor(test_acc_closed), width, label='Test Accuracy', color=color_list[0])
    rects2 = ax.bar(x + width/2, 100*torch.tensor(train_acc_closed), width, label='Train Accuracy', color=color_list[1])
    # remove horizontal lines and spines
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.xaxis.grid(False)
    plt.xticks(np.arange(len(out_projs_closed)), labels, rotation=45)
    plt.legend()
    plt.ylabel('Accuracy [%]')
    plt.ylim([25, 100])
(20, 700) 0.010724677 -0.010629931
(20, 1150) 0.056771163 -0.039086737
(20, 1600) 0.4956497 -0.2596547
(20, 2050) 0.44425982 -0.25678265
(20, 2500) 0.8378215 -0.91546047
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 28.89%
From layer 1:
Accuracy: 55.79%
From layer 2:
Accuracy: 67.67%
From layer 3:
Accuracy: 74.20%
From layer 4:
Accuracy: 75.88%
  0%|          | 0/128 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 66.96%
From layer 1:
Accuracy: 93.44%
From layer 2:
Accuracy: 97.95%
From layer 3:
Accuracy: 99.13%
From layer 4:
Accuracy: 99.50%
In [ ]:
if args.augment:
    # plot some training characteristics
    print(f'Accuracy of last quarter: {100*acc[-len(acc)//4:].mean(axis=0)}%')
    plt.figure()
    for i in range(len(acc[0])):
        plt.plot(np.asarray(acc)[:,i]*100, color=color_list[i])
    plt.ylabel('Accuracy [%]')
    plt.xlabel('Training Step [x500]')
    labels = ['From Inputs directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
    plt.legend(labels)
    plt.ylim([65, 95])
    plt.figure()
    print(losses_out.shape)
    for i in range(losses_out.shape[1]):
        plt.plot(np.arange(len(losses_out))/len(train_loader), savgol_filter(losses_out[:,i], 99, 1), label=labels[i], color=color_list[i])
    plt.ylabel('Cross Entropy Loss')
    plt.xlabel('Training Step')
    plt.ylim([0.15, 1.0])
    plt.legend();
Accuracy of last quarter: [43.03571429 77.07000589 87.01358909 89.62446036 90.12460754]%
torch.Size([25488, 5])

Get output projection Accuracy on test set¶

In [ ]:
test_acc, pred_matrix = get_accuracy(SNN, out_projs, test_loader, device, cat=cat)
plt.figure()
plt.plot(100*np.asarray(test_acc))
plt.ylabel('Accuracy [%]')
plt.xlabel('Layer')

plt.figure()
plt.imshow(pred_matrix, origin='lower')
plt.title('Prediction Matrix for the final layer')
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.xticks([i for i in range(args.n_outputs)])
plt.yticks([i for i in range(args.n_outputs)])
plt.colorbar();
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 46.64%
From layer 1:
Accuracy: 73.14%
From layer 2:
Accuracy: 78.80%
From layer 3:
Accuracy: 80.17%
From layer 4:
Accuracy: 80.96%
In [ ]:
from utils import get_accuracy
if args.augment:
    train_acc, _ = get_accuracy(SNN, out_projs, train_loader, device, cat=cat) 
else:
    test_acc = torch.mean(test_accs, dim=0)
    print(test_acc)
    train_acc = torch.mean(train_accs, dim=0)
# grouped Bar plot the Accuracies of the different layers both during training and testing
sns.set_theme(style="whitegrid")
labels = ['From Inputs Directly', *[f'Until Layer {i+1}' for i in range(len(SNN.layers))]]
if not cat:
    labels = ['From Inputs Directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars
fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, 100*torch.tensor(test_acc), width, label='Test Accuracy', color=color_list[0])
rects2 = ax.bar(x + width/2, 100*torch.tensor(train_acc), width, label='Train Accuracy', color=color_list[1])
if not args.augment:
    ax.errorbar(x - width/2, 100*test_acc, yerr=100*torch.std(test_accs, dim=0), fmt='none', capsize=6, color=color_list[3])
    ax.errorbar(x + width/2, 100*train_acc, yerr=100*torch.std(train_accs, dim=0), fmt='none', capsize=6, color=color_list[3])
# remove horizontal lines and spines
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.xaxis.grid(False)
plt.xticks(np.arange(len(out_projs)), labels, rotation=45)
plt.legend()
plt.ylabel('Accuracy [%]')
plt.ylim([25, 100])
#plt.title('SHD Accuracy');
tensor([0.4502, 0.6792, 0.7394, 0.7537, 0.7663])
/tmp/ipykernel_901041/1496231988.py:16: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  rects1 = ax.bar(x - width/2, 100*torch.tensor(test_acc), width, label='Test Accuracy', color=color_list[0])
/tmp/ipykernel_901041/1496231988.py:17: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  rects2 = ax.bar(x + width/2, 100*torch.tensor(train_acc), width, label='Train Accuracy', color=color_list[1])
Out[ ]:
(25.0, 100.0)

Few Shot Learning (discontinued)¶

In [ ]:
# Randomly select k sample of each class and save the spiking activity
n_outputs = 20
n_repeats = 1
k = 20
fewshot_accuracies = torch.zeros((n_repeats, len(SNN.layers)))
for n in range(n_repeats):
    SNN.reset(0)
    one_shot_samples = torch.zeros(n_outputs, n_time_bins, n_inputs)
    one_shot_spks = torch.zeros(n_outputs, len(SNN.layers), n_hidden[0])
    for i in trange(n_outputs):
        for j in range(k):
            img, _ = train_loader.next_item(i, contrastive=False)
            one_shot_samples[i] = img.squeeze()
            for t in range(n_time_bins):
                logits, mem_his, clapp_loss = SNN(img[t].float(), 0) 
                one_shot_spks[i] += torch.stack(logits).squeeze()

    def metric(spk, one_shot):
        dists = torch.zeros(spk.shape[0], args.n_outputs)
        for i in range(args.n_outputs):
            one_shot_i = one_shot[i] / one_shot[i].sum()
            dists[:, i] = torch.einsum('bi, i->b' , spk, one_shot_i)
        return dists

    def get_predictions(spks):
        preds = torch.zeros(len(spks), spks[0].shape[0])
        # for each layer get the prediction
        for i in range(len(spks)):
            dists = metric(spks[i], one_shot_spks[:,i])
            preds[i] = dists.argmax(axis=-1)
        return preds
    dataset = test_loader
    batch = int(len(dataset)/100)
    correct_oneshot = torch.zeros(len(SNN.layers))
    SNN.eval()
    pred_matrix_oneshot = torch.zeros(n_outputs, n_outputs)
    for idx in trange(0, len(dataset), batch):
        SNN.reset(0)
        inp, target = dataset.x[idx:idx+batch], dataset.y[idx:idx+batch]
        logits = torch.zeros(len(SNN.layers), inp.shape[0], n_hidden[0])
        for step in range(inp.shape[1]):
            data_step = inp[:,step].float().to(device)
            spk_step, _, _ = SNN(data_step, 0)
            logits += torch.stack(spk_step)
        preds = get_predictions(logits)
        for i in range(preds.shape[0]):
            correct_oneshot[i] += int((preds[i] == target).sum())
        # for the last layer create the prediction matrix
        for j in range(preds.shape[1]):
            pred_matrix_oneshot[int(target[j]), int(preds[-1, j])] += 1
    correct_oneshot /= len(dataset)
    for i in range(len(SNN.layers)):
        print(f'From layer {i+1}:')
        print(f'Accuracy: {100*correct_oneshot[i]:.2f}%')
        fewshot_accuracies[n, i] = correct_oneshot[i]
    plt.imshow(pred_matrix_oneshot, origin='lower')
    plt.title('Prediction Matrix for the final layer')
    plt.xlabel('Prediction')
    plt.ylabel('Target')
    plt.xticks([i for i in range(n_outputs)])
    plt.yticks([i for i in range(n_outputs)])
    plt.colorbar();
    plt.figure()
In [ ]:
# Boxplot of the accuracies
plt.figure()
sns.set_style("whitegrid")
g = sns.boxplot(data=fewshot_accuracies*100)
# remove left spines
sns.despine(left=True)
plt.xticks(np.arange(len(SNN.layers)), [f'Layer {i+1}' for i in range(len(SNN.layers))])
plt.ylabel('Few-Shot Test Accuracy [%]')
plt.ylim([0, 100])
print(f'Average Accuracy: {100*fewshot_accuracies.mean(axis=0)}%')
print(f'Maximum Accuracy: {fewshot_accuracies.max(axis=0)}%')